package model

import (
	"database/sql"
	"fmt"
	"log"
	"time"
)

// StockList is one row of the stock_list table. The JSON tags define the
// field names used when a value is marshalled with encoding/json.
type StockList struct {
	// ID is read from the id column; the write paths in this file never set it.
	ID int `json:"id"`
	// Code is the stock code. SaveStockList uses it as the lookup key when
	// deciding between insert and update.
	Code   string `json:"code"`
	Name   string `json:"name"`
	Market string `json:"market"`
	// TotalCapital is stored in the totalcapital_val column.
	TotalCapital float64 `json:"total_capital"`
	// CurrentCapital is stored in the currcapital_val column.
	CurrentCapital float64 `json:"current_capital"`
	// AllNum is stored in the all_num column.
	// NOTE(review): the meaning of this count is not visible in this file.
	AllNum       int    `json:"all_num"`
	TypeID       string `json:"type_id"`
	CategoryName string `json:"category_name"`
	// CreatedAt and UpdatedAt mirror the created_at and updated_at columns.
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}

// GetStockListPage returns one page of rows from the stock_list table,
// ordered by id.
//
// page is 1-based and pageSize is the maximum number of rows to return. A
// page below 1 or a negative pageSize is rejected before the query runs,
// because either value would otherwise be sent to the database as a negative
// LIMIT argument. The returned slice is nil when the page holds no rows.
//
// Errors from the query, from scanning a row, or from parsing the created_at
// and updated_at columns abort the call; no partial result is returned.
func GetStockListPage(page, pageSize int) ([]StockList, error) {
	// Layout of the textual created_at / updated_at values read from the database.
	const dbTimeLayout = "2006-01-02 15:04:05"

	if page < 1 {
		return nil, fmt.Errorf("invalid page %d: must be >= 1", page)
	}
	if pageSize < 0 {
		return nil, fmt.Errorf("invalid page size %d: must be >= 0", pageSize)
	}

	// Number of rows to skip before the requested page starts.
	offset := (page - 1) * pageSize

	rows, err := db.Query("SELECT id, code, name, market, totalcapital_val, currcapital_val, all_num, type_id, category_name, created_at, updated_at FROM stock_list ORDER BY id LIMIT ?, ?", offset, pageSize)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var stockList []StockList

	for rows.Next() {
		// The time columns are scanned as raw bytes and parsed explicitly below.
		var createdAt, updatedAt []uint8

		var stock StockList
		err := rows.Scan(&stock.ID, &stock.Code, &stock.Name, &stock.Market, &stock.TotalCapital, &stock.CurrentCapital, &stock.AllNum, &stock.TypeID, &stock.CategoryName, &createdAt, &updatedAt)
		if err != nil {
			return nil, err
		}

		stock.CreatedAt, err = time.Parse(dbTimeLayout, string(createdAt))
		if err != nil {
			return nil, fmt.Errorf("error parsing created_at: %w", err)
		}

		stock.UpdatedAt, err = time.Parse(dbTimeLayout, string(updatedAt))
		if err != nil {
			return nil, fmt.Errorf("error parsing updated_at: %w", err)
		}

		stockList = append(stockList, stock)
	}

	// rows.Next returns false both at the end of the result set and on
	// failure, so the iteration error has to be checked separately.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return stockList, nil
}

// SaveStockList writes the receiver to the stock_list table: it updates the
// row whose code equals s.Code when one exists, and inserts a new row
// otherwise.
//
// On insert, created_at and updated_at are set to the same instant; on update
// only updated_at is refreshed. Any database error is logged and returned
// unchanged. A write that affects no rows is not treated as an error.
//
// NOTE(review): the existence check and the write are two separate
// statements outside a transaction, so two concurrent saves of the same new
// code can both take the insert path.
func (s *StockList) SaveStockList() error {
	// Find out whether a row with this code is already stored.
	var count int
	err := db.QueryRow("SELECT COUNT(*) FROM stock_list WHERE code = ?", s.Code).Scan(&count)
	if err != nil {
		log.Printf("Failed to check if stock exists: %v", err)
		return err
	}
	exists := count > 0

	// One timestamp per save keeps created_at and updated_at identical on insert.
	now := time.Now()

	// Pick the statement and its arguments together so they cannot drift apart.
	var (
		query string
		args  []interface{}
	)
	if exists {
		query = `
			UPDATE stock_list
			SET name = ?, market = ?, totalcapital_val = ?, currcapital_val = ?, all_num = ?, type_id = ?, category_name = ?, updated_at = ?
			WHERE code = ?
		`
		args = []interface{}{s.Name, s.Market, s.TotalCapital, s.CurrentCapital, s.AllNum, s.TypeID, s.CategoryName, now, s.Code}
	} else {
		query = `
			INSERT INTO stock_list (code, name, market, totalcapital_val, currcapital_val, all_num, type_id, category_name, created_at, updated_at)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
		`
		args = []interface{}{s.Code, s.Name, s.Market, s.TotalCapital, s.CurrentCapital, s.AllNum, s.TypeID, s.CategoryName, now, now}
	}

	var res sql.Result
	res, err = db.Exec(query, args...)
	if err != nil {
		log.Printf("Error saving stock: %v", err)
		return err
	}

	// Only report success when the statement actually touched a row.
	rowsAffected, err := res.RowsAffected()
	if err != nil {
		log.Printf("Error getting rows affected: %v", err)
		return err
	}
	if rowsAffected > 0 {
		if exists {
			log.Printf("Stock with code %s updated successfully", s.Code)
		} else {
			log.Printf("Stock with code %s inserted successfully", s.Code)
		}
	}

	return nil
}

// BatchInsertStocks upserts every element of stocks into the stock_list table
// inside a single transaction.
//
// Each stock is inserted with its code, name, market and capital values; when
// the insert hits a duplicate key, the name, market, capital values and
// updated_at of the existing row are overwritten instead. The all_num,
// type_id and category_name fields are not written by this function.
//
// An empty or nil slice is a no-op and returns nil without touching the
// database. The first failure is logged and returned to the caller, and the
// deferred rollback discards every row written so far, so the batch is
// applied entirely or not at all.
func BatchInsertStocks(stocks []StockList) error {
	if len(stocks) == 0 {
		return nil
	}

	tx, err := db.Begin()
	if err != nil {
		log.Printf("Failed to begin transaction: %v", err)
		return err
	}
	// Undo the batch on any early return. After a successful Commit this call
	// only reports that the transaction is already finished, which is ignored.
	defer tx.Rollback()

	// Prepare the upsert once and reuse it for every row.
	stmt, err := tx.Prepare(`
		INSERT INTO stock_list (code, name, market, totalcapital_val, currcapital_val, created_at, updated_at)
		VALUES (?, ?, ?, ?, ?, ?, ?)
		ON DUPLICATE KEY UPDATE
			name = VALUES(name),
			market = VALUES(market),
			totalcapital_val = VALUES(totalcapital_val),
			currcapital_val = VALUES(currcapital_val),
			updated_at = VALUES(updated_at)
	`)
	if err != nil {
		log.Printf("Failed to prepare statement: %v", err)
		return err
	}
	defer stmt.Close()

	// Index the slice rather than ranging by value to avoid copying each struct.
	for i := range stocks {
		stock := &stocks[i]
		// One timestamp per row keeps created_at and updated_at identical on insert.
		now := time.Now()
		if _, err := stmt.Exec(stock.Code, stock.Name, stock.Market, stock.TotalCapital, stock.CurrentCapital, now, now); err != nil {
			log.Printf("Error inserting stock with code %s: %v", stock.Code, err)
			return err
		}
	}

	if err := tx.Commit(); err != nil {
		log.Printf("Error committing transaction: %v", err)
		return err
	}

	log.Printf("Batch insertion of %d stocks successful", len(stocks))
	return nil
}
